From f5a9cd55277c9d101dc50f42703ce37cf6482250 Mon Sep 17 00:00:00 2001 From: fsuzanomassa Date: Wed, 22 Apr 2015 18:47:52 +0200 Subject: Adding in-place AddConstant and MulConstant --- test.lua | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) (limited to 'test.lua') diff --git a/test.lua b/test.lua index 23c7fbd..e0afbbb 100644 --- a/test.lua +++ b/test.lua @@ -2528,6 +2528,43 @@ function nntest.AddConstant() -- Test BPROP local err = jac.testJacobian(mod, input) mytester:assertlt(err, precision, 'bprop error ') + + -- inplace comparisons + local ini = math.random(3,5) + local inj = math.random(3,5) + local ink = math.random(3,5) + local constant = torch.uniform()*math.random(1,10) + + local input1 = torch.rand(ink, inj, ini) + local input2 = input1:clone() + + local module1 = nn.AddConstant(constant,true) + local module2 = nn.AddConstant(constant) + + local gradOutput1 = torch.rand(ink, inj, ini) + local gradOutput2 = gradOutput1:clone() + + local out1 = module1:forward(input1) + local out2 = module2:forward(input2) + + mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) .. + ' - in-place forward err ') + + local gradInput1 = module1:backward(input1, gradOutput1) + local gradInput2 = module2:backward(input2, gradOutput2) + + mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), + torch.typename(module1) .. ' - in-place backward err ') + + local input1 = torch.rand(ink, inj, ini) + local input2 = input1:clone() + + module1:forward(input1) + module1:backward(module1.output,torch.rand(input1:size())) + + local err = (input1-input2):abs():max() + mytester:asserteq(err, 0, torch.typename(module1) .. + ' - inplace input change err ') end function nntest.MulConstant() @@ -2548,6 +2585,43 @@ function nntest.MulConstant() -- Test BPROP local err = jac.testJacobian(mod, input) mytester:assertlt(err, precision, 'bprop error ') + + -- inplace comparisons + local ini = math.random(3,5) + local inj = math.random(3,5) + local ink = math.random(3,5) + local constant = torch.uniform()*math.random(1,10) + + local input1 = torch.rand(ink, inj, ini) + local input2 = input1:clone() + + local module1 = nn.MulConstant(constant,true) + local module2 = nn.MulConstant(constant) + + local gradOutput1 = torch.rand(ink, inj, ini) + local gradOutput2 = gradOutput1:clone() + + local out1 = module1:forward(input1) + local out2 = module2:forward(input2) + + mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) .. + ' - in-place forward err ') + + local gradInput1 = module1:backward(input1, gradOutput1) + local gradInput2 = module2:backward(input2, gradOutput2) + + mytester:asserteq(0, (gradInput1-gradInput2):abs():max(), + torch.typename(module1) .. ' - in-place backward err ') + + local input1 = torch.rand(ink, inj, ini) + local input2 = input1:clone() + + module1:forward(input1) + module1:backward(module1.output,torch.rand(input1:size())) + + local err = (input1-input2):abs():max() + mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) .. + ' - inplace input change err ') end function nntest.Copy() -- cgit v1.2.3