From 3af4cebeef5f66f4a390e326ea7f2b3be3f0f370 Mon Sep 17 00:00:00 2001 From: nicholas-leonard Date: Wed, 9 Jul 2014 17:00:26 -0400 Subject: ReLU unit test --- test/test.lua | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'test') diff --git a/test/test.lua b/test/test.lua index cc82e9e..45cc2fe 100644 --- a/test/test.lua +++ b/test/test.lua @@ -78,6 +78,18 @@ function nntest.Dropout() mytester:assert(math.abs(gradInput:mean() - (1-p)) < 0.05, 'dropout gradInput') end +function nntest.ReLU() + local input = torch.randn(3,4) + local gradOutput = torch.randn(3,4) + local module = nn.ReLU() + local output = module:forward(input) + local output2 = input:clone():gt(input, 0):cmul(input) + mytester:assertTensorEq(output, output2, 0.000001, 'ReLU output') + local gradInput = module:backward(input, gradOutput) + local gradInput2 = input:clone():gt(input, 0):cmul(gradOutput) + mytester:assertTensorEq(gradInput, gradInput2, 0.000001, 'ReLU gradInput') +end + function nntest.Exp() local ini = math.random(10,20) local inj = math.random(10,20) -- cgit v1.2.3