From 3fc211c48b5ad270f374cb18733b0354cc4b9335 Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Tue, 28 Apr 2015 17:40:07 -0400 Subject: ParallelCriterion --- test.lua | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) (limited to 'test.lua') diff --git a/test.lua b/test.lua index a28674a..9414a66 100644 --- a/test.lua +++ b/test.lua @@ -811,6 +811,66 @@ function nntest.MarginRankingCriterion() mytester:assert(torch.type(gradInput2[2]) == 'torch.FloatTensor', "MRC:type() error 2") end +function nntest.ParallelCriterion() + local input = {torch.rand(2,10), torch.randn(2,10)} + local target = {torch.IntTensor{1,8}, torch.randn(2,10)} + local nll = nn.ClassNLLCriterion() + local mse = nn.MSECriterion() + local pc = nn.ParallelCriterion():add(nll, 0.5):add(mse) + local output = pc:forward(input, target) + local output2 = nll:forward(input[1], target[1])/2 + mse:forward(input[2], target[2]) + mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion forward error") + local gradInput = pc:backward(input, target) + local gradInput2 = {nll:backward(input[1], target[1]):clone():div(2), mse:backward(input[2], target[2])} + mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion backward error 1") + mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion backward error 2") + -- test type + pc:float() + gradInput[1], gradInput[2] = gradInput[1]:clone(), gradInput[2]:clone() + local input3 = {input[1]:float(), input[2]:float()} + local target3 = {target[1]:float(), target[2]:float()} + local output3 = pc:forward(input3, target3) + local gradInput3 = pc:backward(input3, target3) + mytester:assert(math.abs(output3 - output) < 0.00001, "ParallelCriterion forward error type") + mytester:assertTensorEq(gradInput[1]:float(), gradInput3[1], 0.000001, "ParallelCriterion backward error 1 type") + mytester:assertTensorEq(gradInput[2]:float(), gradInput3[2], 0.000001, "ParallelCriterion backward error 2 type") + -- test repeatTarget + local input = {torch.rand(2,10), torch.randn(2,10)} + local target = torch.randn(2,10) + local mse = nn.MSECriterion() + local pc = nn.ParallelCriterion(true):add(mse, 0.5):add(mse:clone()) + local output = pc:forward(input, target) + local output2 = mse:forward(input[1], target)/2 + mse:forward(input[2], target) + mytester:assert(math.abs(output2 - output) < 0.00001, "ParallelCriterion repeatTarget forward error") + local gradInput = pc:backward(input, target) + local gradInput2 = {mse:backward(input[1], target):clone():div(2), mse:backward(input[2], target)} + mytester:assertTensorEq(gradInput[1], gradInput2[1], 0.000001, "ParallelCriterion repeatTarget backward error 1") + mytester:assertTensorEq(gradInput[2], gradInput2[2], 0.000001, "ParallelCriterion repeatTarget backward error 2") +end + +function nntest.MultiCriterion() + local input = torch.rand(2,10) + local target = torch.IntTensor{1,8} + local nll = nn.ClassNLLCriterion() + local nll2 = nn.CrossEntropyCriterion() + local mc = nn.MultiCriterion():add(nll, 0.5):add(nll2) + local output = mc:forward(input, target) + local output2 = nll:forward(input, target)/2 + nll2:forward(input, target) + mytester:assert(math.abs(output2 - output) < 0.00001, "MultiCriterion forward error") + local gradInput = mc:backward(input, target) + local gradInput2 = nll:backward(input, target):clone():div(2):add(nll2:backward(input, target)) + mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "MultiCriterion backward error ") + -- test type + mc:float() + gradInput = gradInput:clone() + local input3 = input:float() + local target3 = target:float() + local output3 = mc:forward(input3, target3) + local gradInput3 = mc:backward(input3, target3) + mytester:assert(math.abs(output3 - output) < 0.00001, "MultiCriterion forward error type") + mytester:assertTensorEq(gradInput:float(), gradInput3, 0.000001, "MultiCriterion backward error type") +end + function nntest.WeightedMSECriterion() local input = torch.rand(10) local target = input:clone():add(torch.rand(10)) -- cgit v1.2.3