Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2015-04-29 00:40:07 +0300
committerNicholas Leonard <nick@nikopia.org>2015-05-05 21:58:49 +0300
commit3fc211c48b5ad270f374cb18733b0354cc4b9335 (patch)
treec5084b7f3363721d2b90440c2c1ff979a6f72c1b /test.lua
parent59d56be820159c8620062a95182cada6e2384ffc (diff)
ParallelCriterion
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua60
1 files changed, 60 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index a28674a..9414a66 100644
--- a/test.lua
+++ b/test.lua
@@ -811,6 +811,66 @@ function nntest.MarginRankingCriterion()
mytester:assert(torch.type(gradInput2[2]) == 'torch.FloatTensor', "MRC:type() error 2")
end
+function nntest.ParallelCriterion()
+ local input = {torch.rand(2,10), torch.randn(2,10)}
+ local target = {torch.IntTensor{1,8}, torch.randn(2,10)}
+ local nll = nn.ClassNLLCriterion()
+ local mse = nn.MSECriterion()
+ local pc = nn.ParallelCriterion():add(nll, 0.5):add(mse)
+ local output = pc:forward(input, target)
+ local output2 = nll:forward(input[1], target[1])/2 + mse:forward(input[2], target[2])
+ mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion forward error")
+ local gradInput = pc:backward(input, target)
+ local gradInput2 = {nll:backward(input[1], target[1]):clone():div(2), mse:backward(input[2], target[2])}
+ mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion backward error 1")
+ mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion backward error 2")
+ -- test type
+ pc:float()
+ gradInput[1], gradInput[2] = gradInput[1]:clone(), gradInput[2]:clone()
+ local input3 = {input[1]:float(), input[2]:float()}
+ local target3 = {target[1]:float(), target[2]:float()}
+ local output3 = pc:forward(input3, target3)
+ local gradInput3 = pc:backward(input3, target3)
+ mytester:assert(math.abs(output3 - output) < 0.00001, "ParallelCriterion forward error type")
+ mytester:assertTensorEq(gradInput[1]:float(), gradInput3[1], 0.000001, "ParallelCriterion backward error 1 type")
+ mytester:assertTensorEq(gradInput[2]:float(), gradInput3[2], 0.000001, "ParallelCriterion backward error 2 type")
+ -- test repeatTarget
+ local input = {torch.rand(2,10), torch.randn(2,10)}
+ local target = torch.randn(2,10)
+ local mse = nn.MSECriterion()
+ local pc = nn.ParallelCriterion(true):add(mse, 0.5):add(mse:clone())
+ local output = pc:forward(input, target)
+ local output2 = mse:forward(input[1], target)/2 + mse:forward(input[2], target)
+ mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion repeatTarget forward error")
+ local gradInput = pc:backward(input, target)
+ local gradInput2 = {mse:backward(input[1], target):clone():div(2), mse:backward(input[2], target)}
+ mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion repeatTarget backward error 1")
+ mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion repeatTarget backward error 2")
+end
+
+function nntest.MultiCriterion()
+ local input = torch.rand(2,10)
+ local target = torch.IntTensor{1,8}
+ local nll = nn.ClassNLLCriterion()
+ local nll2 = nn.CrossEntropyCriterion()
+ local mc = nn.MultiCriterion():add(nll, 0.5):add(nll2)
+ local output = mc:forward(input, target)
+ local output2 = nll:forward(input, target)/2 + nll2:forward(input, target)
+ mytester:assert(math.abs(output2 - output) < 0.00001, "MultiCriterion forward error")
+ local gradInput = mc:backward(input, target)
+ local gradInput2 = nll:backward(input, target):clone():div(2):add(nll2:backward(input, target))
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "MultiCriterion backward error ")
+ -- test type
+ mc:float()
+ gradInput = gradInput:clone()
+ local input3 = input:float()
+ local target3 = target:float()
+ local output3 = mc:forward(input3, target3)
+ local gradInput3 = mc:backward(input3, target3)
+ mytester:assert(math.abs(output3 - output) < 0.00001, "MultiCriterion forward error type")
+ mytester:assertTensorEq(gradInput:float(), gradInput3, 0.000001, "MultiCriterion backward error type")
+end
+
function nntest.WeightedMSECriterion()
local input = torch.rand(10)
local target = input:clone():add(torch.rand(10))