diff options
author | nicholas-leonard <nick@nikopia.org> | 2015-01-07 23:50:19 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2015-01-08 00:13:46 +0300 |
commit | 63489b23cc2aa136f51711a579861fa1ef536566 (patch) | |
tree | 46fe86dcf2ccdcc374090e00e4a9200be784defe /test.lua | |
parent | 6f53dd515034cc582832db79e23912885a749d1d (diff) |
CMul optimizations, doc and unit tests
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 26 |
1 files changed, 24 insertions, 2 deletions
@@ -74,7 +74,7 @@ function nntest.CMul() local inj = math.random(3,5) local ink = math.random(3,5) local input = torch.Tensor(ini,inj,ink):zero() - local module = nn.CMul(ini*inj*ink) + local module = nn.CMul(ini, inj, ink) -- 1D local err = jac.testJacobian(module,input) @@ -94,7 +94,29 @@ function nntest.CMul() -- 2D local nframe = math.random(50,70) local nframe = 5 - local input = torch.Tensor(nframe, ini,inj,ink):zero() + local input = torch.randn(nframe, ini,inj,ink) + local output = module:forward(input) + local output2 = torch.cmul(input, module.weight:view(1,ini,inj,ink):expandAs(input)) + mytester:assertTensorEq(output2, output, 0.000001, 'CMul forward 2D err') + + module:zeroGradParameters() + local gradWeight = module.gradWeight:clone() + local gradInput = module:backward(input, output) + local gradInput2 = gradInput:clone():zero() + local outputView = output:view(input:size(1), -1) + gradInput2:view(input:size(1), -1):addcmul(1, module.weight:view(1,-1):expandAs(outputView), outputView) + mytester:assertTensorEq(gradInput2, gradInput, 0.000001, 'CMul updateGradInput 2D err') + mytester:assert(gradInput:isSameSizeAs(input), 'CMul gradInput 2D size err') + + local inputView = input:view(nframe, -1) + local gradWeightView = gradWeight:view(1, -1) + for i=1,nframe do + gradWeightView:addcmul(1, inputView[i], outputView[i]) + end + mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'CMul accGradParameters 2D err') + mytester:assert(module.weight:isSameSizeAs(module.gradWeight), 'CMul gradWeight size err') + + input:zero() local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') |