Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfsuzanomassa <fvsmassa@gmail.com>2015-04-22 19:47:52 +0300
committerfsuzanomassa <fvsmassa@gmail.com>2015-04-22 19:47:52 +0300
commitf5a9cd55277c9d101dc50f42703ce37cf6482250 (patch)
treeb0023792d785e3b6830f8459c100b8b89ab7152d /test.lua
parent418624f67da0c61dd2a7205373e3ebe816a94aae (diff)
Adding in-place AddConstant and MulConstant
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua74
1 files changed, 74 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 23c7fbd..e0afbbb 100644
--- a/test.lua
+++ b/test.lua
@@ -2528,6 +2528,43 @@ function nntest.AddConstant()
-- Test BPROP
local err = jac.testJacobian(mod, input)
mytester:assertlt(err, precision, 'bprop error ')
+
+ -- inplace comparisons
+ local ini = math.random(3,5)
+ local inj = math.random(3,5)
+ local ink = math.random(3,5)
+ local constant = torch.uniform()*math.random(1,10)
+
+ local input1 = torch.rand(ink, inj, ini)
+ local input2 = input1:clone()
+
+ local module1 = nn.AddConstant(constant,true)
+ local module2 = nn.AddConstant(constant)
+
+ local gradOutput1 = torch.rand(ink, inj, ini)
+ local gradOutput2 = gradOutput1:clone()
+
+ local out1 = module1:forward(input1)
+ local out2 = module2:forward(input2)
+
+ mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) ..
+ ' - in-place forward err ')
+
+ local gradInput1 = module1:backward(input1, gradOutput1)
+ local gradInput2 = module2:backward(input2, gradOutput2)
+
+ mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
+ torch.typename(module1) .. ' - in-place backward err ')
+
+ local input1 = torch.rand(ink, inj, ini)
+ local input2 = input1:clone()
+
+ module1:forward(input1)
+ module1:backward(module1.output,torch.rand(input1:size()))
+
+ local err = (input1-input2):abs():max()
+ mytester:asserteq(err, 0, torch.typename(module1) ..
+ ' - inplace input change err ')
end
function nntest.MulConstant()
@@ -2548,6 +2585,43 @@ function nntest.MulConstant()
-- Test BPROP
local err = jac.testJacobian(mod, input)
mytester:assertlt(err, precision, 'bprop error ')
+
+ -- inplace comparisons
+ local ini = math.random(3,5)
+ local inj = math.random(3,5)
+ local ink = math.random(3,5)
+ local constant = torch.uniform()*math.random(1,10)
+
+ local input1 = torch.rand(ink, inj, ini)
+ local input2 = input1:clone()
+
+ local module1 = nn.MulConstant(constant,true)
+ local module2 = nn.MulConstant(constant)
+
+ local gradOutput1 = torch.rand(ink, inj, ini)
+ local gradOutput2 = gradOutput1:clone()
+
+ local out1 = module1:forward(input1)
+ local out2 = module2:forward(input2)
+
+ mytester:asserteq(0, (out1-out2):abs():max(), torch.typename(module1) ..
+ ' - in-place forward err ')
+
+ local gradInput1 = module1:backward(input1, gradOutput1)
+ local gradInput2 = module2:backward(input2, gradOutput2)
+
+ mytester:asserteq(0, (gradInput1-gradInput2):abs():max(),
+ torch.typename(module1) .. ' - in-place backward err ')
+
+ local input1 = torch.rand(ink, inj, ini)
+ local input2 = input1:clone()
+
+ module1:forward(input1)
+ module1:backward(module1.output,torch.rand(input1:size()))
+
+ local err = (input1-input2):abs():max()
+ mytester:assertalmosteq(err, 0, 1e-15, torch.typename(module1) ..
+ ' - inplace input change err ')
end
function nntest.Copy()