diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-06-23 22:47:46 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2016-06-23 22:47:46 +0300 |
commit | 3fd281c6e5dcb90bc80ef0083dc778a164c31159 (patch) | |
tree | cc78bf1382a738d917f8c0a8f0d4623edcb0047c /test | |
parent | cc2d151af59715084a6e837cab873907cbffa22b (diff) |
deal with fp16 batchnorm
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/test/test.lua b/test/test.lua index e82f1eb..c612771 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1296,7 +1296,6 @@ function cudnntest.SpatialLogSoftMax() end local function testBatchNormalization(moduleName, inputSize) - if testparams.test_type == 'torch.CudaHalfTensor' then return end local input = torch.randn(table.unpack(inputSize)):cuda() local gradOutput = torch.randn(table.unpack(inputSize)):cuda() local cbn = cast(cudnn[moduleName](inputSize[2], 1e-3)) @@ -1331,7 +1330,7 @@ local function testBatchNormalization(moduleName, inputSize) local function testFWD(cbn, gbn) cbn:evaluate() gbn:evaluate() - local rescuda = cbn:forward(input:type(cbn:type())) + local rescuda = cbn:forward(cast(input)) local groundtruth = gbn:forward(input) local error = rescuda:float() - groundtruth:float() |