From 3fd281c6e5dcb90bc80ef0083dc778a164c31159 Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Thu, 23 Jun 2016 21:47:46 +0200 Subject: deal with fp16 batchnorm --- test/test.lua | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'test') diff --git a/test/test.lua b/test/test.lua index e82f1eb..c612771 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1296,7 +1296,6 @@ function cudnntest.SpatialLogSoftMax() end local function testBatchNormalization(moduleName, inputSize) - if testparams.test_type == 'torch.CudaHalfTensor' then return end local input = torch.randn(table.unpack(inputSize)):cuda() local gradOutput = torch.randn(table.unpack(inputSize)):cuda() local cbn = cast(cudnn[moduleName](inputSize[2], 1e-3)) @@ -1331,7 +1330,7 @@ local function testBatchNormalization(moduleName, inputSize) local function testFWD(cbn, gbn) cbn:evaluate() gbn:evaluate() - local rescuda = cbn:forward(input:type(cbn:type())) + local rescuda = cbn:forward(cast(input)) local groundtruth = gbn:forward(input) local error = rescuda:float() - groundtruth:float() -- cgit v1.2.3