diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-01-10 02:11:57 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-01-10 02:11:57 +0300 |
commit | e8fadc69a11fb24829cafbe800e6a6e4948899fd (patch) | |
tree | de13ce2049506eba18310733216cae7447a5ae4f /test.lua | |
parent | 1f21615bbb5110edbf380162d6c1e47ca72406b4 (diff) | |
parent | dace3a23ba412ffdd1669b20aa550f4ee9451d4c (diff) |
Merge pull request #128 from nicholas-leonard/euclidean2
Euclidean batch support
Diffstat (limited to 'test.lua')
-rw-r--r-- | test.lua | 49 |
1 files changed, 47 insertions, 2 deletions
@@ -496,9 +496,54 @@ end function nntest.Euclidean() local ini = math.random(5,7) local inj = math.random(5,7) - local input = torch.Tensor(ini):zero() + local input = torch.randn(ini) + local gradOutput = torch.randn(inj) local module = nn.Euclidean(ini,inj) - + local output = module:forward(input):clone() + + local output2 = torch.Tensor(inj):zero() + for o = 1,module.weight:size(2) do + output2[o] = input:dist(module.weight:select(2,o)) + end + mytester:assertTensorEq(output, output2, 0.000001, 'Euclidean forward 1D err') + + local input2 = torch.randn(8, ini) + input2[2]:copy(input) + local output2 = module:forward(input2) + mytester:assertTensorEq(output2[2], output, 0.000001, 'Euclidean forward 2D err') + + local output = module:forward(input):clone() + module:zeroGradParameters() + local gradInput = module:backward(input, gradOutput, 1):clone() + local gradInput2 = torch.zeros(ini) + local temp = input:clone() + for o = 1,module.weight:size(2) do + temp:copy(input) + temp:add(-1,module.weight:select(2,o)) + temp:mul(gradOutput[o]/output[o]) + gradInput2:add(temp) + end + mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'Euclidean updateGradInput 1D err') + + local gradWeight = module.gradWeight:clone():zero() + for o = 1,module.weight:size(2) do + temp:copy(module.weight:select(2,o)):add(-1,input) + temp:mul(gradOutput[o]/output[o]) + gradWeight:select(2,o):add(1, temp) + end + mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 1D err') + + local input2 = input:view(1, -1):repeatTensor(8, 1) + local gradOutput2 = gradOutput:view(1, -1):repeatTensor(8, 1) + local output2 = module:forward(input2) + module:zeroGradParameters() + local gradInput2 = module:backward(input2, gradOutput2, 1/8) + mytester:assertTensorEq(gradInput2[2], gradInput, 0.000001, 'Euclidean updateGradInput 2D err') + + mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 2D err') + + input:zero() + module.fastBackward = false local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') |