diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index a448317..aac45b4 100644 --- a/test/test.lua +++ b/test/test.lua @@ -962,6 +962,16 @@ function cudnntest.ReLU_batch() nonlinBatch('ReLU') end +function cudnntest.ClippedReLU_single() + local input = torch.randn(1, 32):cuda() + local ceiling = 0.1 + local module = cudnn.ClippedReLU(true, ceiling):cuda() + local output = module:forward(input) + local expectedOutput = input:clone() + expectedOutput[expectedOutput:ge(ceiling)] = ceiling + mytester:assertTensorEq(output, expectedOutput) +end + function cudnntest.Tanh_single() nonlinSingle('Tanh') end |