Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-01-10 02:11:57 +0300
committerSoumith Chintala <soumith@gmail.com>2015-01-10 02:11:57 +0300
commite8fadc69a11fb24829cafbe800e6a6e4948899fd (patch)
treede13ce2049506eba18310733216cae7447a5ae4f /test.lua
parent1f21615bbb5110edbf380162d6c1e47ca72406b4 (diff)
parentdace3a23ba412ffdd1669b20aa550f4ee9451d4c (diff)
Merge pull request #128 from nicholas-leonard/euclidean2
Euclidean batch support
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua49
1 files changed, 47 insertions, 2 deletions
diff --git a/test.lua b/test.lua
index f2b3e90..ab91abb 100644
--- a/test.lua
+++ b/test.lua
@@ -496,9 +496,54 @@ end
function nntest.Euclidean()
local ini = math.random(5,7)
local inj = math.random(5,7)
- local input = torch.Tensor(ini):zero()
+ local input = torch.randn(ini)
+ local gradOutput = torch.randn(inj)
local module = nn.Euclidean(ini,inj)
-
+ local output = module:forward(input):clone()
+
+ local output2 = torch.Tensor(inj):zero()
+ for o = 1,module.weight:size(2) do
+ output2[o] = input:dist(module.weight:select(2,o))
+ end
+ mytester:assertTensorEq(output, output2, 0.000001, 'Euclidean forward 1D err')
+
+ local input2 = torch.randn(8, ini)
+ input2[2]:copy(input)
+ local output2 = module:forward(input2)
+ mytester:assertTensorEq(output2[2], output, 0.000001, 'Euclidean forward 2D err')
+
+ local output = module:forward(input):clone()
+ module:zeroGradParameters()
+ local gradInput = module:backward(input, gradOutput, 1):clone()
+ local gradInput2 = torch.zeros(ini)
+ local temp = input:clone()
+ for o = 1,module.weight:size(2) do
+ temp:copy(input)
+ temp:add(-1,module.weight:select(2,o))
+ temp:mul(gradOutput[o]/output[o])
+ gradInput2:add(temp)
+ end
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'Euclidean updateGradInput 1D err')
+
+ local gradWeight = module.gradWeight:clone():zero()
+ for o = 1,module.weight:size(2) do
+ temp:copy(module.weight:select(2,o)):add(-1,input)
+ temp:mul(gradOutput[o]/output[o])
+ gradWeight:select(2,o):add(1, temp)
+ end
+ mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 1D err')
+
+ local input2 = input:view(1, -1):repeatTensor(8, 1)
+ local gradOutput2 = gradOutput:view(1, -1):repeatTensor(8, 1)
+ local output2 = module:forward(input2)
+ module:zeroGradParameters()
+ local gradInput2 = module:backward(input2, gradOutput2, 1/8)
+ mytester:assertTensorEq(gradInput2[2], gradInput, 0.000001, 'Euclidean updateGradInput 2D err')
+
+ mytester:assertTensorEq(gradWeight, module.gradWeight, 0.000001, 'Euclidean accGradParameters 2D err')
+
+ input:zero()
+ module.fastBackward = false
local err = jac.testJacobian(module,input)
mytester:assertlt(err,precision, 'error on state ')