author     Nicholas Leonard <nick@nikopia.org>   2015-07-25 22:51:29 +0300
committer  Nicholas Leonard <nick@nikopia.org>   2015-07-25 22:51:29 +0300
commit     73ddbe49b8a58eca739e3f442bcb101b6aff0978
tree       cb91414d9c918f3b5d6e865f882d0529f4f8e82d
parent     3c13b0d41281990e3f1d4dd8a1ae45b8d5ca0f40
[Parallel,Multi]Criterion table inputs
-rw-r--r--  MultiCriterion.lua    |  6
-rw-r--r--  ParallelCriterion.lua |  6
-rw-r--r--  test.lua              | 45
-rw-r--r--  utils.lua             | 21
4 files changed, 71 insertions, 7 deletions
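
In short, this commit lets ParallelCriterion and MultiCriterion accept arbitrarily nested tables of inputs and targets: gradInput is resized, zeroed and accumulated through the recursive helpers in nn.utils instead of tensor calls. A minimal usage sketch (tensor sizes and criterion choices are illustrative, mirroring the new test cases in test.lua below):

require 'nn'

-- nested {tensor, {tensor, tensor}} input with matching nested targets
local input  = {torch.randn(2,10), {torch.rand(2,10), torch.randn(2,10)}}
local target = {torch.IntTensor{2,5}, {torch.IntTensor{1,8}, torch.randn(2,10)}}

-- criterion containers can now be nested to match the table structure
local inner = nn.ParallelCriterion():add(nn.ClassNLLCriterion(), 0.5):add(nn.MSECriterion())
local outer = nn.ParallelCriterion():add(nn.ClassNLLCriterion(), 0.4):add(inner)

local loss = outer:forward(input, target)
local gradInput = outer:backward(input, target)  -- same nested structure as input

Each leaf criterion sees the matching leaf of the input and target tables, and the returned gradInput mirrors the nesting.
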
diff --git a/MultiCriterion.lua b/MultiCriterion.lua
index 4ad5a73..801ad27 100644
--- a/MultiCriterion.lua
+++ b/MultiCriterion.lua
@@ -23,10 +23,10 @@ function MultiCriterion:updateOutput(input, target)
 end
 
 function MultiCriterion:updateGradInput(input, target)
-   self.gradInput:resizeAs(input)
-   self.gradInput:zero()
+   self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
+   nn.utils.recursiveFill(self.gradInput, 0)
    for i=1,#self.criterions do
-      self.gradInput:add(self.weights[i], self.criterions[i]:updateGradInput(input, target))
+      nn.utils.recursiveAdd(self.gradInput, self.weights[i], self.criterions[i]:updateGradInput(input, target))
    end
    return self.gradInput
 end
diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua
index 95bd6cc..84d4ee1 100644
--- a/ParallelCriterion.lua
+++ b/ParallelCriterion.lua
@@ -25,11 +25,11 @@ function ParallelCriterion:updateOutput(input, target)
 end
 
 function ParallelCriterion:updateGradInput(input, target)
+   self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
+   nn.utils.recursiveFill(self.gradInput, 0)
    for i,criterion in ipairs(self.criterions) do
       local target = self.repeatTarget and target or target[i]
-      self.gradInput[i] = self.gradInput[i] or input[i].new()
-      self.gradInput[i]:resizeAs(input[i]):zero()
-      self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target))
+      nn.utils.recursiveAdd(self.gradInput[i], self.weights[i], criterion:updateGradInput(input[i], target))
    end
    return self.gradInput
 end
diff --git a/test.lua b/test.lua
--- a/test.lua
+++ b/test.lua
@@ -847,10 +847,11 @@ function nntest.ParallelCriterion()
    local output = pc:forward(input, target)
    local output2 = nll:forward(input[1], target[1])/2 + mse:forward(input[2], target[2])
    mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion forward error")
-   local gradInput = pc:backward(input, target)
    local gradInput2 = {nll:backward(input[1], target[1]):clone():div(2), mse:backward(input[2], target[2])}
+   local gradInput = pc:backward(input, target)
    mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion backward error 1")
    mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion backward error 2")
+
    -- test type
    pc:float()
    gradInput[1], gradInput[2] = gradInput[1]:clone(), gradInput[2]:clone()
@@ -861,6 +862,7 @@ function nntest.ParallelCriterion()
    mytester:assert(math.abs(output3 - output) < 0.00001, "ParallelCriterion forward error type")
    mytester:assertTensorEq(gradInput[1]:float(), gradInput3[1], 0.000001, "ParallelCriterion backward error 1 type")
    mytester:assertTensorEq(gradInput[2]:float(), gradInput3[2], 0.000001, "ParallelCriterion backward error 2 type")
+
    -- test repeatTarget
    local input = {torch.rand(2,10), torch.randn(2,10)}
    local target = torch.randn(2,10)
@@ -873,6 +875,26 @@ function nntest.ParallelCriterion()
    local gradInput2 = {mse:backward(input[1], target):clone():div(2), mse:backward(input[2], target)}
    mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion repeatTarget backward error 1")
    mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion repeatTarget backward error 2")
+
+   -- table input
+   local input = {torch.randn(2,10), {torch.rand(2,10), torch.randn(2,10)}}
+   local target = {torch.IntTensor{2,5}, {torch.IntTensor{1,8}, torch.randn(2,10)}}
+   local nll2 = nn.ClassNLLCriterion()
+   local nll = nn.ClassNLLCriterion()
+   local mse = nn.MSECriterion()
+   local pc = nn.ParallelCriterion():add(nll, 0.5):add(mse)
+   local pc2 = nn.ParallelCriterion():add(nll2, 0.4):add(pc)
+   local output = pc2:forward(input, target)
+   local output2 = nll2:forward(input[1], target[1])*0.4 + nll:forward(input[2][1], target[2][1])/2 + mse:forward(input[2][2], target[2][2])
+   mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion table forward error")
+   local gradInput2 = {
+      nll2:backward(input[1], target[1]):clone():mul(0.4),
+      {nll:backward(input[2][1], target[2][1]):clone():div(2), mse:backward(input[2][2], target[2][2])}
+   }
+   local gradInput = pc2:backward(input, target)
+   mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion table backward error 1")
+   mytester:assertTensorEq(gradInput[2][1], gradInput2[2][1], 0.000001, "ParallelCriterion table backward error 2")
+   mytester:assertTensorEq(gradInput[2][2], gradInput2[2][2], 0.000001, "ParallelCriterion table backward error 3")
 end
 
 function nntest.MultiCriterion()
@@ -887,6 +909,7 @@ function nntest.MultiCriterion()
    local gradInput = mc:backward(input, target)
    local gradInput2 = nll:backward(input, target):clone():div(2):add(nll2:backward(input, target))
    mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "MultiCriterion backward error ")
+
    -- test type
    mc:float()
    gradInput = gradInput:clone()
@@ -896,6 +919,26 @@ function nntest.MultiCriterion()
    local gradInput3 = mc:backward(input3, target3)
    mytester:assert(math.abs(output3 - output) < 0.00001, "MultiCriterion forward error type")
    mytester:assertTensorEq(gradInput:float(), gradInput3, 0.000001, "MultiCriterion backward error type")
+
+   -- test table input
+   mc:double()
+   local input = {torch.randn(2,10), {torch.randn(2,10), torch.randn(2,10)}}
+   local target = {torch.IntTensor{1,8}, {torch.IntTensor{5,6}, torch.IntTensor{4,3}}}
+   local pnllc = nn.ParallelCriterion():add(nll):add(nn.ParallelCriterion():add(nll:clone()):add(nll:clone()))
+   local pnllc2 = nn.ParallelCriterion():add(nll2):add(nn.ParallelCriterion():add(nll2:clone()):add(nll2:clone()))
+   local mc = nn.MultiCriterion():add(pnllc, 0.5):add(pnllc2)
+   local output = mc:forward(input, target)
+   local output2 = pnllc:forward(input, target)/2 + pnllc2:forward(input, target)
+   mytester:assert(math.abs(output2 - output) < 0.00001, "MultiCriterion forward table error")
+   local gradInput = mc:backward(input, target)
+   local gradInput2 = pnllc:clone():backward(input, target)
+   local gradInput2b = pnllc2:backward(input, target)
+   gradInput2[1]:div(2):add(gradInput2b[1])
+   gradInput2[2][1]:div(2):add(gradInput2b[2][1])
+   gradInput2[2][2]:div(2):add(gradInput2b[2][2])
+   mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "MultiCriterion backward table 1 error ")
+   mytester:assertTensorEq(gradInput[2][1], gradInput2[2][1], 0.000001, "MultiCriterion backward table 2 error ")
+   mytester:assertTensorEq(gradInput[2][2], gradInput2[2][2], 0.000001, "MultiCriterion backward table 3 error ")
 end
 
 function nntest.WeightedMSECriterion()
diff --git a/utils.lua b/utils.lua
--- a/utils.lua
+++ b/utils.lua
@@ -44,6 +44,27 @@ function nn.utils.recursiveFill(t2, val)
    return t2
 end
 
+function nn.utils.recursiveAdd(t1, val, t2)
+   if not t2 then
+      assert(val, "expecting at least two arguments")
+      t2 = val
+      val = 1
+   end
+   val = val or 1
+   if torch.type(t2) == 'table' then
+      t1 = (torch.type(t1) == 'table') and t1 or {t1}
+      for key,_ in pairs(t2) do
+         t1[key], t2[key] = nn.utils.recursiveAdd(t1[key], val, t2[key])
+      end
+   elseif torch.isTensor(t1) and torch.isTensor(t2) then
+      t1:add(val, t2)
+   else
+      error("expecting nested tensors or tables. Got "..
+            torch.type(t1).." and "..torch.type(t2).." instead")
+   end
+   return t1, t2
+end
+
 function nn.utils.addSingletonDimension(t, dim)
    local view = t.new()
    local size = torch.LongStorage(t:dim() + 1)