Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-12-23 21:52:31 +0300
committerSoumith Chintala <soumith@gmail.com>2014-12-23 21:52:31 +0300
commit289c6a2d91dfec05cc5c55105353e408b8541334 (patch)
tree14138f62030a9e312f6d2cc3502a1f263814440c /test.lua
parenteecf41831e177b08dda3bcc51c6aa70c9df2f2d1 (diff)
parent4178c4ef3e17425940022047d7dd12a645d3ed11 (diff)
Merge pull request #119 from nicholas-leonard/weightedeuclidean
WeightedEuclidean batch mode
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua37
1 files changed, 36 insertions, 1 deletions
diff --git a/test.lua b/test.lua
index 29de8bc..56298e0 100644
--- a/test.lua
+++ b/test.lua
@@ -477,7 +477,7 @@ function nntest.WeightedEuclidean()
local inj = math.random(13,5)
local input = torch.Tensor(ini):zero()
local module = nn.WeightedEuclidean(ini,inj)
-
+
local err = jac.testJacobian(module,input)
mytester:assertlt(err,precision, 'error on state ')
@@ -490,6 +490,41 @@ function nntest.WeightedEuclidean()
local ferr,berr = jac.testIO(module,input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ -- test batch
+ local bs = math.random(3,5)
+ input:uniform(0,1)
+ local output = module:forward(input):clone()
+ module:zeroGradParameters()
+ local gradInput = module:backward(input, output):clone()
+ local params, gradParams = module:parameters()
+ for i=1,#params do
+ params[i] = params[i]:clone()
+ end
+ local input2 = input:view(1, -1):repeatTensor(bs, 1)
+ local output2 = module:forward(input2)
+ module:zeroGradParameters()
+ local gradInput2 = module:backward(input2, output2, 1/bs)
+ local params2, gradParams2 = module:parameters()
+ mytester:assertTensorEq(output2[bs-1], output, 0.000001, "error in batch updateOutput")
+ mytester:assertTensorEq(gradInput2[bs-1], gradInput, 0.000001, "error in batch updateGradInput")
+ mytester:assertTensorEq(gradParams[1], gradParams2[1], 0.000001, "error in batch accGradParameters (gradTemplates)")
+ mytester:assertTensorEq(gradParams[2], gradParams2[2], 0.000001, "error in batch accGradParameters (gradDiagCov)")
+
+ input:zero()
+ module:zeroGradParameters()
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
+ mytester:assertlt(err,precision, 'error on weight ')
+
+ local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
+ mytester:assertlt(err,precision, 'error on bias ')
+
+ local ferr,berr = jac.testIO(module,input2)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
local function criterionJacobianTest1D(cri, input, target)