diff options
Diffstat (limited to 'test/test.lua')
-rw-r--r-- | test/test.lua | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 9ecc923..73426fb 100644 --- a/test/test.lua +++ b/test/test.lua @@ -60,6 +60,36 @@ function nntest.CMul() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') end +function nntest.Dropout() + local p = 0.2 --prob of droping out a neuron + local input = torch.Tensor(1000):fill((1-p)) + local module = nn.Dropout(p) + -- version 2 + local output = module:forward(input) + mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output') + local gradInput = module:backward(input, input) + mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput') + -- version 1 (old nnx version) + local input = input:fill(1) + local module = nn.Dropout(p,true) + local output = module:forward(input) + mytester:assert(math.abs(output:mean() - (1-p)) < 0.05, 'dropout output') + local gradInput = module:backward(input, input) + mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput') +end + +function nntest.ReLU() + local input = torch.randn(3,4) + local gradOutput = torch.randn(3,4) + local module = nn.ReLU() + local output = module:forward(input) + local output2 = input:clone():gt(input, 0):cmul(input) + mytester:assertTensorEq(output, output2, 0.000001, 'ReLU output') + local gradInput = module:backward(input, gradOutput) + local gradInput2 = input:clone():gt(input, 0):cmul(gradOutput) + mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU gradInput') +end + function nntest.Exp() local ini = math.random(10,20) local inj = math.random(10,20) @@ -1869,6 +1899,49 @@ function nntest.SplitTable() end end +function nntest.ElementTable() + local input = { + torch.rand(3,4,5), torch.rand(3,4,5), + {torch.rand(3,4,5)}, + {torch.rand(3,4,5), {torch.rand(3,4,5)}} + } + local gradOutputs = { + torch.rand(3,4,5), torch.rand(3,4,5), + {torch.rand(3,4,5)}, + {torch.rand(3,4,5), {torch.rand(3,4,5)}} + } + local zeros = { + torch.Tensor(3,4,5):zero(), torch.Tensor(3,4,5):zero(), + {torch.Tensor(3,4,5):zero()}, + {torch.Tensor(3,4,5):zero(), {torch.Tensor(3,4,5):zero()}} + } + local function equal(t1, t2, msg) + if (torch.type(t1) == "table") then + for k, v in pairs(t2) do + equal(t1[k], t2[k]) + end + else + mytester:assertTensorEq(t1, t2, 0.00001, msg) + end + end + local nonIdx = {2,3,4,1} + local module + for idx = 1,#input do + module = nn.ElementTable(idx) + local output = module:forward(input) + equal(output, input[idx], "output dimension " .. idx) + local gradInput = module:backward(input, gradOutputs[idx]) + equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) + end + module:float() + local idx = #input + local output = module:forward(input) + equal(output, input[idx], "type output") + local gradInput = module:backward(input, gradOutputs[idx]) + equal(gradInput[idx], gradOutputs[idx], "gradInput[idx] dimension " .. idx) + equal(gradInput[nonIdx[idx]], zeros[nonIdx[idx]], "gradInput[nonIdx] dimension " .. idx) +end function nntest.View() local input = torch.rand(10) |