
github.com/torch/nn.git
author     Nicholas Leonard <nick@nikopia.org>  2015-07-25 22:51:29 +0300
committer  Nicholas Leonard <nick@nikopia.org>  2015-07-25 22:51:29 +0300
commit     73ddbe49b8a58eca739e3f442bcb101b6aff0978
tree       cb91414d9c918f3b5d6e865f882d0529f4f8e82d
parent     3c13b0d41281990e3f1d4dd8a1ae45b8d5ca0f40
[Parallel,Multi]Criterion table inputs
-rw-r--r--  MultiCriterion.lua     6
-rw-r--r--  ParallelCriterion.lua  6
-rw-r--r--  test.lua              45
-rw-r--r--  utils.lua             21
4 files changed, 71 insertions(+), 7 deletions(-)
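This commit teaches MultiCriterion and ParallelCriterion to accept nested tables of tensors as input and target. The flat-tensor bookkeeping for gradInput is replaced by the recursive helpers nn.utils.recursiveResizeAs and nn.utils.recursiveFill, and a new helper, nn.utils.recursiveAdd, accumulates weighted gradients over arbitrarily nested tensor/table structures.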
diff --git a/MultiCriterion.lua b/MultiCriterion.lua
index 4ad5a73..801ad27 100644
--- a/MultiCriterion.lua
+++ b/MultiCriterion.lua
@@ -23,10 +23,10 @@ function MultiCriterion:updateOutput(input, target)
end
function MultiCriterion:updateGradInput(input, target)
- self.gradInput:resizeAs(input)
- self.gradInput:zero()
+ self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
+ nn.utils.recursiveFill(self.gradInput, 0)
for i=1,#self.criterions do
- self.gradInput:add(self.weights[i], self.criterions[i]:updateGradInput(input, target))
+ nn.utils.recursiveAdd(self.gradInput, self.weights[i], self.criterions[i]:updateGradInput(input, target))
end
return self.gradInput
end
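To illustrate (a minimal sketch, not part of the commit; the wiring below is hypothetical): MultiCriterion feeds the same input and target to every added criterion, so after this change the shared input may be a table, provided each added criterion handles tables itself, for instance a ParallelCriterion.

require 'nn'

-- hypothetical sketch: one table input scored by two table-capable criterions
local input  = {torch.randn(2,10), torch.randn(2,10)}
local target = {torch.IntTensor{1,8}, torch.IntTensor{5,6}}

local pc1 = nn.ParallelCriterion():add(nn.ClassNLLCriterion()):add(nn.ClassNLLCriterion())
local pc2 = nn.ParallelCriterion():add(nn.ClassNLLCriterion()):add(nn.ClassNLLCriterion())
local mc  = nn.MultiCriterion():add(pc1, 0.5):add(pc2)

local loss      = mc:forward(input, target)   -- 0.5*pc1 + 1.0*pc2
local gradInput = mc:backward(input, target)  -- a table mirroring input
-- before this patch, updateGradInput called resizeAs/zero/add on a flat
-- tensor and errored on table inputs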
diff --git a/ParallelCriterion.lua b/ParallelCriterion.lua
index 95bd6cc..84d4ee1 100644
--- a/ParallelCriterion.lua
+++ b/ParallelCriterion.lua
@@ -25,11 +25,11 @@ function ParallelCriterion:updateOutput(input, target)
end
function ParallelCriterion:updateGradInput(input, target)
+ self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
+ nn.utils.recursiveFill(self.gradInput, 0)
for i,criterion in ipairs(self.criterions) do
local target = self.repeatTarget and target or target[i]
- self.gradInput[i] = self.gradInput[i] or input[i].new()
- self.gradInput[i]:resizeAs(input[i]):zero()
- self.gradInput[i]:add(self.weights[i], criterion:updateGradInput(input[i],target))
+ nn.utils.recursiveAdd(self.gradInput[i], self.weights[i], criterion:updateGradInput(input[i], target))
end
return self.gradInput
end
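In the same spirit, a hedged sketch of nested tables (mirroring the new test below; the exact criterions are illustrative): ParallelCriterion routes input[i] and target[i] to its i-th criterion, and because gradInput is now resized and zeroed recursively, a criterion may itself be a ParallelCriterion over a sub-table.

-- illustrative sketch: a nested ParallelCriterion over a nested table input
local input  = {torch.randn(2,10), {torch.randn(2,10), torch.randn(2,10)}}
local target = {torch.IntTensor{2,5}, {torch.IntTensor{1,8}, torch.IntTensor{4,3}}}

local inner = nn.ParallelCriterion():add(nn.ClassNLLCriterion()):add(nn.ClassNLLCriterion())
local outer = nn.ParallelCriterion():add(nn.ClassNLLCriterion(), 0.5):add(inner)

local loss      = outer:forward(input, target)
local gradInput = outer:backward(input, target)
-- gradInput[2][1] and gradInput[2][2] match the shapes of input[2][1] and input[2][2]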
diff --git a/test.lua b/test.lua
index cd48433..e198227 100644
--- a/test.lua
+++ b/test.lua
@@ -847,10 +847,11 @@ function nntest.ParallelCriterion()
local output = pc:forward(input, target)
local output2 = nll:forward(input[1], target[1])/2 + mse:forward(input[2], target[2])
mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion forward error")
- local gradInput = pc:backward(input, target)
local gradInput2 = {nll:backward(input[1], target[1]):clone():div(2), mse:backward(input[2], target[2])}
+ local gradInput = pc:backward(input, target)
mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion backward error 1")
mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion backward error 2")
+
-- test type
pc:float()
gradInput[1], gradInput[2] = gradInput[1]:clone(), gradInput[2]:clone()
@@ -861,6 +862,7 @@ function nntest.ParallelCriterion()
mytester:assert(math.abs(output3 - output) < 0.00001, "ParallelCriterion forward error type")
mytester:assertTensorEq(gradInput[1]:float(), gradInput3[1], 0.000001, "ParallelCriterion backward error 1 type")
mytester:assertTensorEq(gradInput[2]:float(), gradInput3[2], 0.000001, "ParallelCriterion backward error 2 type")
+
-- test repeatTarget
local input = {torch.rand(2,10), torch.randn(2,10)}
local target = torch.randn(2,10)
@@ -873,6 +875,26 @@ function nntest.ParallelCriterion()
local gradInput2 = {mse:backward(input[1], target):clone():div(2), mse:backward(input[2], target)}
mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion repeatTarget backward error 1")
mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion repeatTarget backward error 2")
+
+ -- table input
+ local input = {torch.randn(2,10), {torch.rand(2,10), torch.randn(2,10)}}
+ local target = {torch.IntTensor{2,5}, {torch.IntTensor{1,8}, torch.randn(2,10)}}
+ local nll2 = nn.ClassNLLCriterion()
+ local nll = nn.ClassNLLCriterion()
+ local mse = nn.MSECriterion()
+ local pc = nn.ParallelCriterion():add(nll, 0.5):add(mse)
+ local pc2 = nn.ParallelCriterion():add(nll2, 0.4):add(pc)
+ local output = pc2:forward(input, target)
+ local output2 = nll2:forward(input[1], target[1])*0.4 + nll:forward(input[2][1], target[2][1])/2 + mse:forward(input[2][2], target[2][2])
+ mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion table forward error")
+ local gradInput2 = {
+ nll2:backward(input[1], target[1]):clone():mul(0.4),
+ {nll:backward(input[2][1], target[2][1]):clone():div(2), mse:backward(input[2][2], target[2][2])}
+ }
+ local gradInput = pc2:backward(input, target)
+ mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion table backward error 1")
+ mytester:assertTensorEq(gradInput[2][1], gradInput2[2][1], 0.000001, "ParallelCriterion table backward error 2")
+ mytester:assertTensorEq(gradInput[2][2], gradInput2[2][2], 0.000001, "ParallelCriterion table backward error 3")
end
function nntest.MultiCriterion()
@@ -887,6 +909,7 @@ function nntest.MultiCriterion()
local gradInput = mc:backward(input, target)
local gradInput2 = nll:backward(input, target):clone():div(2):add(nll2:backward(input, target))
mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "MultiCriterion backward error ")
+
-- test type
mc:float()
gradInput = gradInput:clone()
@@ -896,6 +919,26 @@ function nntest.MultiCriterion()
local gradInput3 = mc:backward(input3, target3)
mytester:assert(math.abs(output3 - output) < 0.00001, "MultiCriterion forward error type")
mytester:assertTensorEq(gradInput:float(), gradInput3, 0.000001, "MultiCriterion backward error type")
+
+ -- test table input
+ mc:double()
+ local input = {torch.randn(2,10), {torch.randn(2,10), torch.randn(2,10)}}
+ local target = {torch.IntTensor{1,8}, {torch.IntTensor{5,6}, torch.IntTensor{4,3}}}
+ local pnllc = nn.ParallelCriterion():add(nll):add(nn.ParallelCriterion():add(nll:clone()):add(nll:clone()))
+ local pnllc2 = nn.ParallelCriterion():add(nll2):add(nn.ParallelCriterion():add(nll2:clone()):add(nll2:clone()))
+ local mc = nn.MultiCriterion():add(pnllc, 0.5):add(pnllc2)
+ local output = mc:forward(input, target)
+ local output2 = pnllc:forward(input, target)/2 + pnllc2:forward(input, target)
+ mytester:assert(math.abs(output2 - output) < 0.00001, "MultiCriterion forward table error")
+ local gradInput = mc:backward(input, target)
+ local gradInput2 = pnllc:clone():backward(input, target)
+ local gradInput2b = pnllc2:backward(input, target)
+ gradInput2[1]:div(2):add(gradInput2b[1])
+ gradInput2[2][1]:div(2):add(gradInput2b[2][1])
+ gradInput2[2][2]:div(2):add(gradInput2b[2][2])
+ mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "MultiCriterion backward table 1 error ")
+ mytester:assertTensorEq(gradInput[2][1], gradInput2[2][1], 0.000001, "MultiCriterion backward table 2 error ")
+ mytester:assertTensorEq(gradInput[2][2], gradInput2[2][2], 0.000001, "MultiCriterion backward table 3 error ")
end
function nntest.WeightedMSECriterion()
diff --git a/utils.lua b/utils.lua
index 78a22a2..74ff7e6 100644
--- a/utils.lua
+++ b/utils.lua
@@ -44,6 +44,27 @@ function nn.utils.recursiveFill(t2, val)
return t2
end
+function nn.utils.recursiveAdd(t1, val, t2)
+ if not t2 then
+ assert(val, "expecting at least two arguments")
+ t2 = val
+ val = 1
+ end
+ val = val or 1
+ if torch.type(t2) == 'table' then
+ t1 = (torch.type(t1) == 'table') and t1 or {t1}
+ for key,_ in pairs(t2) do
+ t1[key], t2[key] = nn.utils.recursiveAdd(t1[key], val, t2[key])
+ end
+ elseif torch.isTensor(t1) and torch.isTensor(t2) then
+ t1:add(val, t2)
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+ return t1, t2
+end
+
function nn.utils.addSingletonDimension(t, dim)
local view = t.new()
local size = torch.LongStorage(t:dim() + 1)
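Finally, a small usage sketch of the new helper on made-up tensors: nn.utils.recursiveAdd(t1, val, t2) computes t1 = t1 + val * t2 elementwise over matching nested structures of tensors and tables; the scale argument is optional and defaults to 1.

-- illustrative: weighted in-place accumulation over nested tables of tensors
local acc = {torch.zeros(3), {torch.zeros(3)}}
local g   = {torch.ones(3),  {torch.ones(3)}}

nn.utils.recursiveAdd(acc, 0.5, g)  -- acc[1] and acc[2][1] now hold 0.5 everywhere
nn.utils.recursiveAdd(acc, g)       -- scale omitted, defaults to 1; now 1.5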