Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-06-23 22:47:46 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-06-23 22:47:46 +0300
commit3fd281c6e5dcb90bc80ef0083dc778a164c31159 (patch)
treecc78bf1382a738d917f8c0a8f0d4623edcb0047c /test
parentcc2d151af59715084a6e837cab873907cbffa22b (diff)
deal with fp16 batchnorm
Diffstat (limited to 'test')
-rw-r--r--test/test.lua3
1 files changed, 1 insertions, 2 deletions
diff --git a/test/test.lua b/test/test.lua
index e82f1eb..c612771 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1296,7 +1296,6 @@ function cudnntest.SpatialLogSoftMax()
end
local function testBatchNormalization(moduleName, inputSize)
- if testparams.test_type == 'torch.CudaHalfTensor' then return end
local input = torch.randn(table.unpack(inputSize)):cuda()
local gradOutput = torch.randn(table.unpack(inputSize)):cuda()
local cbn = cast(cudnn[moduleName](inputSize[2], 1e-3))
@@ -1331,7 +1330,7 @@ local function testBatchNormalization(moduleName, inputSize)
local function testFWD(cbn, gbn)
cbn:evaluate()
gbn:evaluate()
- local rescuda = cbn:forward(input:type(cbn:type()))
+ local rescuda = cbn:forward(cast(input))
local groundtruth = gbn:forward(input)
local error = rescuda:float() - groundtruth:float()