Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-01-10 02:11:06 +0300
committerSoumith Chintala <soumith@gmail.com>2015-01-10 02:11:06 +0300
commit1f21615bbb5110edbf380162d6c1e47ca72406b4 (patch)
tree06421b3d9e5e99beedcc8c3de500cf7c3a3667f7 /test.lua
parentc596c339786cf0674ef31b43a6c243678bb000e2 (diff)
parent517c6c0e36046f13167a4b77b06c464747ddc0fc (diff)
Merge pull request #133 from nicholas-leonard/weightedeuclidean2
WeightedEuclidean optimizations
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua93
1 file changed, 68 insertions(+), 25 deletions(-)
diff --git a/test.lua b/test.lua
index 65ff3b1..f2b3e90 100644
--- a/test.lua
+++ b/test.lua
@@ -511,44 +511,87 @@ function nntest.Euclidean()
end
function nntest.WeightedEuclidean()
- local ini = math.random(3,5)
- local inj = math.random(13,5)
- local input = torch.Tensor(ini):zero()
+ local ini = math.random(5,7)
+ local inj = math.random(5,7)
+ local input = torch.randn(ini)
+ local gradOutput = torch.randn(inj)
local module = nn.WeightedEuclidean(ini,inj)
+ local output = module:forward(input):clone()
+
+ local output2 = torch.Tensor(inj):zero()
+ local temp = input:clone()
+ for o = 1,module.weight:size(2) do
+ temp:copy(input):add(-1,module.weight:select(2,o))
+ temp:cmul(temp)
+ temp:cmul(module.diagCov:select(2,o)):cmul(module.diagCov:select(2,o))
+ output2[o] = math.sqrt(temp:sum())
+ end
+ mytester:assertTensorEq(output, output2, 0.000001, 'WeightedEuclidean forward 1D err')
+
+ local input2 = torch.randn(8, ini)
+ input2[2]:copy(input)
+ local output2 = module:forward(input2)
+ mytester:assertTensorEq(output2[2], output, 0.000001, 'WeightedEuclidean forward 2D err')
+
+ local output = module:forward(input):clone()
+ module:zeroGradParameters()
+ local gradInput = module:backward(input, gradOutput, 1):clone()
+ local gradInput2 = torch.zeros(ini)
+ for o = 1,module.weight:size(2) do
+ temp:copy(input)
+ temp:add(-1,module.weight:select(2,o))
+ temp:cmul(module.diagCov:select(2,o)):cmul(module.diagCov:select(2,o))
+ temp:mul(gradOutput[o]/output[o])
+ gradInput2:add(temp)
+ end
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'WeightedEuclidean updateGradInput 1D err')
+
+ local gradWeight = module.gradWeight:clone():zero()
+ local gradDiagCov = module.gradDiagCov:clone():zero()
+ for o = 1,module.weight:size(2) do
+ if output[o] ~= 0 then
+ temp:copy(module.weight:select(2,o)):add(-1,input)
+ temp:cmul(module.diagCov:select(2,o)):cmul(module.diagCov:select(2,o))
+ temp:mul(gradOutput[o]/output[o])
+ gradWeight:select(2,o):add(temp)
+
+ temp:copy(module.weight:select(2,o)):add(-1,input)
+ temp:cmul(temp)
+ temp:cmul(module.diagCov:select(2,o))
+ temp:mul(gradOutput[o]/output[o])
+ gradDiagCov:select(2,o):add(temp)
+ end
+ end
+ mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'WeightedEuclidean accGradParameters gradWeight 1D err')
+ mytester:assertTensorEq(gradDiagCov, module.gradDiagCov, 0.000001, 'WeightedEuclidean accGradParameters gradDiagCov 1D err')
+
+ local input2 = input:view(1, -1):repeatTensor(8, 1)
+ local gradOutput2 = gradOutput:view(1, -1):repeatTensor(8, 1)
+ local output2 = module:forward(input2)
+ module:zeroGradParameters()
+ local gradInput2 = module:backward(input2, gradOutput2, 1/8)
+ mytester:assertTensorEq(gradInput2[2], gradInput, 0.000001, 'WeightedEuclidean updateGradInput 2D err')
+
+ mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'WeightedEuclidean accGradParameters gradWeight 2D err')
+ mytester:assertTensorEq(gradDiagCov, module.gradDiagCov, 0.000001, 'WeightedEuclidean accGradParameters gradDiagCov 2D err')
+
+ input:zero()
+ module.fastBackward = false
+
local err = jac.testJacobian(module,input)
mytester:assertlt(err,precision, 'error on state ')
local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
mytester:assertlt(err,precision, 'error on weight ')
- local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
+ local err = jac.testJacobianParameters(module, input, module.diagCov, module.gradDiagCov)
mytester:assertlt(err,precision, 'error on bias ')
local ferr,berr = jac.testIO(module,input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
- -- test batch
- local bs = math.random(3,5)
- input:uniform(0,1)
- local output = module:forward(input):clone()
- module:zeroGradParameters()
- local gradInput = module:backward(input, output):clone()
- local params, gradParams = module:parameters()
- for i=1,#params do
- params[i] = params[i]:clone()
- end
- local input2 = input:view(1, -1):repeatTensor(bs, 1)
- local output2 = module:forward(input2)
- module:zeroGradParameters()
- local gradInput2 = module:backward(input2, output2, 1/bs)
- local params2, gradParams2 = module:parameters()
- mytester:assertTensorEq(output2[bs-1], output, 0.000001, "error in batch updateOutput")
- mytester:assertTensorEq(gradInput2[bs-1], gradInput, 0.000001, "error in batch updateGradInput")
- mytester:assertTensorEq(gradParams[1], gradParams2[1], 0.000001, "error in batch accGradParameters (gradTemplates)")
- mytester:assertTensorEq(gradParams[2], gradParams2[2], 0.000001, "error in batch accGradParameters (gradDiagCov)")
-
input:zero()
module:zeroGradParameters()
local err = jac.testJacobian(module,input)
@@ -557,7 +600,7 @@ function nntest.WeightedEuclidean()
local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
mytester:assertlt(err,precision, 'error on weight ')
- local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
+ local err = jac.testJacobianParameters(module, input, module.diagCov, module.gradDiagCov)
mytester:assertlt(err,precision, 'error on bias ')
local ferr,berr = jac.testIO(module,input2)