Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2014-06-25 20:14:12 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2014-06-25 20:14:12 +0400
commite406feabe2b1bd7f6b7a15826aff1e925fe713d6 (patch)
tree8177ddef413e068c95882e00d228feca0900ebaf /test
parentea9cc1df751ddb144c08a13aab3add1ab0ce90a1 (diff)
Added very simple add and mul constant modules.
Diffstat (limited to 'test')
-rw-r--r--test/test.lua40
1 files changed, 40 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index be17fd7..775dded 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1789,6 +1789,46 @@ function nntest.LookupTable()
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
+
+function nntest.AddConstant()
+ local nbatch = torch.random(3, 5)
+ local f = torch.random(3, 5)
+ local h = torch.random(10,20)
+ local w = torch.random(10,20)
+ local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10]
+
+ local constant = torch.randn(1):squeeze()
+ local mod = nn.AddConstant(constant)
+
+ -- Test FPROP
+ local output = mod:forward(input)
+ local delta = output - input
+ mytester:assertlt(delta:add(-constant):abs():max(), precision, 'fprop error')
+
+ -- Test BPROP
+ local err = jac.testJacobian(mod, input)
+ mytester:assertlt(err, precision, 'bprop error ')
+end
+
+function nntest.MulConstant()
+ local nbatch = torch.random(3, 5)
+ local f = torch.random(3, 5)
+ local h = torch.random(10,20)
+ local w = torch.random(10,20)
+ local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10]
+
+ local constant = torch.randn(1):squeeze()
+ local mod = nn.MulConstant(constant)
+
+ -- Test FPROP
+ local output = mod:forward(input)
+ local scale = output:clone():cdiv(input)
+ mytester:assertlt(scale:add(-constant):abs():max(), precision, 'fprop error')
+
+ -- Test BPROP
+ local err = jac.testJacobian(mod, input)
+ mytester:assertlt(err, precision, 'bprop error ')
+end
mytester:add(nntest)