From e406feabe2b1bd7f6b7a15826aff1e925fe713d6 Mon Sep 17 00:00:00 2001 From: Jonathan Tompson Date: Wed, 25 Jun 2014 12:14:12 -0400 Subject: Added very simple add and mul constant modules. --- test/test.lua | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) (limited to 'test') diff --git a/test/test.lua b/test/test.lua index be17fd7..775dded 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1789,6 +1789,46 @@ function nntest.LookupTable() mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end + +function nntest.AddConstant() + local nbatch = torch.random(3, 5) + local f = torch.random(3, 5) + local h = torch.random(10,20) + local w = torch.random(10,20) + local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10] + + local constant = torch.randn(1):squeeze() + local mod = nn.AddConstant(constant) + + -- Test FPROP + local output = mod:forward(input) + local delta = output - input + mytester:assertlt(delta:add(-constant):abs():max(), precision, 'fprop error') + + -- Test BPROP + local err = jac.testJacobian(mod, input) + mytester:assertlt(err, precision, 'bprop error ') +end + +function nntest.MulConstant() + local nbatch = torch.random(3, 5) + local f = torch.random(3, 5) + local h = torch.random(10,20) + local w = torch.random(10,20) + local input = torch.rand(nbatch, f, h, w):mul(20):add(-10) -- [-10, 10] + + local constant = torch.randn(1):squeeze() + local mod = nn.MulConstant(constant) + + -- Test FPROP + local output = mod:forward(input) + local scale = output:clone():cdiv(input) + mytester:assertlt(scale:add(-constant):abs():max(), precision, 'fprop error') + + -- Test BPROP + local err = jac.testJacobian(mod, input) + mytester:assertlt(err, precision, 'bprop error ') +end mytester:add(nntest) -- cgit v1.2.3