
github.com/torch/nn.git
author     Soumith Chintala <soumith@gmail.com>  2014-07-10 18:28:55 +0400
committer  Soumith Chintala <soumith@gmail.com>  2014-07-10 18:28:55 +0400
commit     1625eef2d4448c75711c7a87811b8a9e81a9ac40 (patch)
tree       d4c9f0b38c1cba20fcf1178ec41c3736b3b40fbf /test
parent     75a2279ef3dac76046f128a2d77e1ffd2dcd5397 (diff)
parent     ca59595974936b9ff42377677e7553627f397197 (diff)
Merge github.com:torch/nn into concattable_tableinput
Diffstat (limited to 'test')
-rw-r--r--  test/test.lua  73
1 file changed, 73 insertions(+), 0 deletions(-)
diff --git a/test/test.lua b/test/test.lua
index 9ecc923..73426fb 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -60,6 +60,36 @@ function nntest.CMul()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end
+function nntest.Dropout()
+ local p = 0.2 -- probability of dropping out a neuron
+ local input = torch.Tensor(1000):fill((1-p))
+ local module = nn.Dropout(p)
+ -- version 2
+ local output = module:forward(input)
+ mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
+ local gradInput = module:backward(input, input)
+ mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
+ -- version 1 (old nnx version)
+ local input = input:fill(1)
+ local module = nn.Dropout(p,true)
+ local output = module:forward(input)
+ mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output')
+ local gradInput = module:backward(input, input)
+ mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput')
+end
+
+function nntest.ReLU()
+ local input = torch.randn(3,4)
+ local gradOutput = torch.randn(3,4)
+ local module = nn.ReLU()
+ local output = module:forward(input)
+ local output2 = input:clone():gt(input, 0):cmul(input)
+ mytester:assertTensorEq(output, output2, 0.000001, 'ReLU output')
+ local gradInput = module:backward(input, gradOutput)
+ local gradInput2 = input:clone():gt(input, 0):cmul(gradOutput)
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU gradInput')
+end
+
function nntest.Exp()
local ini = math.random(10,20)
local inj = math.random(10,20)
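
The new Dropout test checks the output mean against (1-p). A minimal sketch of why that is the expected value in both versions, runnable outside the test harness (assumes only that the torch and nn packages are installed; this is not part of the commit):

    require 'nn'

    local p = 0.2
    local x = torch.Tensor(1000):fill(1 - p)

    -- Version 2 (the default): units are dropped with probability p and
    -- the survivors are rescaled by 1/(1-p) at training time, so the
    -- expected output equals the input; with x filled with (1-p), the
    -- sample mean stays near (1-p).
    local drop2 = nn.Dropout(p)
    print(drop2:forward(x):mean())         -- approximately 1 - p

    -- Version 1 (the old nnx behaviour, selected by the second
    -- constructor argument): no rescaling, so the expected output is
    -- (1-p) times the input; with x filled with 1, the mean is again
    -- near (1-p).
    local drop1 = nn.Dropout(p, true)
    print(drop1:forward(x:fill(1)):mean()) -- approximately 1 - p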
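The ReLU test builds its reference values with the in-place comparison idiom res:gt(tensor, value), which fills res with a 0/1 mask. A short sketch of the idiom under the same assumptions:

    require 'nn'

    local input = torch.randn(3, 4)
    -- mask holds 1 where input > 0 and 0 elsewhere; cmul then zeroes
    -- the negative entries, which is exactly max(input, 0), i.e. ReLU.
    local mask = input:clone():gt(input, 0)
    local expected = mask:cmul(input)
    local output = nn.ReLU():forward(input)
    print((output - expected):abs():max()) -- 0, up to the test tolerance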
@@ -1869,6 +1899,49 @@ function nntest.SplitTable()
end
end
+function nntest.ElementTable()
+ local input = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local gradOutputs = {
+ torch.rand(3,4,5), torch.rand(3,4,5),
+ {torch.rand(3,4,5)},
+ {torch.rand(3,4,5), {torch.rand(3,4,5)}}
+ }
+ local zeros = {
+ torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(),
+ {torch.Tensor(3,4,5):zero()},
+ {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}}
+ }
+ local function equal(t1, t2, msg)
+ if (torch.type(t1) == "table") then
+ for k, v in pairs(t2) do
+ equal(t1[k], t2[k], msg)
+ end
+ else
+ mytester:assertTensorEq(t1, t2, 0.00001, msg)
+ end
+ end
+ local nonIdx = {2,3,4,1}
+ local module
+ for idx = 1,#input do
+ module = nn.ElementTable(idx)
+ local output = module:forward(input)
+ equal(output, input[idx], "output dimension " .. idx)
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+ end
+ module:float()
+ local idx = #input
+ local output = module:forward(input)
+ equal(output, input[idx], "type output")
+ local gradInput = module:backward(input, gradOutputs[idx])
+ equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx)
+ equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx)
+end
function nntest.View()
local input = torch.rand(10)
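
The second hunk tests nn.ElementTable, which forwards one element of an input table; on backward it routes gradOutput to the selected element and hands every other element a zero tensor of matching shape (that is what the zeros table above checks). A minimal usage sketch, again assuming only that torch and nn are installed:

    require 'nn'

    local select2 = nn.ElementTable(2)  -- forward input[2], ignore the rest
    local input = {torch.rand(3), torch.rand(3), torch.rand(3)}
    local output = select2:forward(input)            -- the tensor input[2]
    local gradInput = select2:backward(input, torch.ones(3))
    -- gradInput[2] carries the gradOutput; gradInput[1] and gradInput[3]
    -- are zero tensors of the same size as their inputs.
    print(gradInput[1]:sum(), gradInput[2]:sum(), gradInput[3]:sum())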