diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index be17fd7..775dded 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1789,6 +1789,46 @@ function nntest.LookupTable() mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end + +function nntest.AddConstant() + local nbatch = torch.random(3, 5) + local f = torch.random(3, 5) + local h = torch.random(10,20) + local w = torch.random(10,20) + local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10] + + local constant = torch.randn(1):squeeze() + local mod = nn.AddConstant(constant) + + -- Test FPROP + local output = mod:forward(input) + local delta = output - input + mytester:assertlt(delta:add(-constant):abs():max(), precision, 'fprop error') + + -- Test BPROP + local err = jac.testJacobian(mod, input) + mytester:assertlt(err, precision, 'bprop error ') +end + +function nntest.MulConstant() + local nbatch = torch.random(3, 5) + local f = torch.random(3, 5) + local h = torch.random(10,20) + local w = torch.random(10,20) + local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10] + + local constant = torch.randn(1):squeeze() + local mod = nn.MulConstant(constant) + + -- Test FPROP + local output = mod:forward(input) + local scale = output:clone():cdiv(input) + mytester:assertlt(scale:add(-constant):abs():max(), precision, 'fprop error') + + -- Test BPROP + local err = jac.testJacobian(mod, input) + mytester:assertlt(err, precision, 'bprop error ') +end mytester:add(nntest) |