diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index cc82e9e..45cc2fe 100644 --- a/test/test.lua +++ b/test/test.lua @@ -78,6 +78,18 @@ function nntest.Dropout() mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput') end +function nntest.ReLU() + local input = torch.randn(3,4) + local gradOutput = torch.randn(3,4) + local module = nn.ReLU() + local output = module:forward(input) + local output2 = input:clone():gt(input, 0):cmul(input) + mytester:assertTensorEq(output, output2, 0.000001, 'ReLU output') + local gradInput = module:backward(input, gradOutput) + local gradInput2 = input:clone():gt(input, 0):cmul(gradOutput) + mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU gradInput') +end + function nntest.Exp() local ini = math.random(10,20) local inj = math.random(10,20) |