Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test.lua3
1 files changed, 1 insertions, 2 deletions
diff --git a/test/test.lua b/test/test.lua
index e82f1eb..c612771 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1296,7 +1296,6 @@ function cudnntest.SpatialLogSoftMax()
end
local function testBatchNormalization(moduleName, inputSize)
- if testparams.test_type == 'torch.CudaHalfTensor' then return end
local input = torch.randn(table.unpack(inputSize)):cuda()
local gradOutput = torch.randn(table.unpack(inputSize)):cuda()
local cbn = cast(cudnn[moduleName](inputSize[2], 1e-3))
@@ -1331,7 +1330,7 @@ local function testBatchNormalization(moduleName, inputSize)
local function testFWD(cbn, gbn)
cbn:evaluate()
gbn:evaluate()
- local rescuda = cbn:forward(input:type(cbn:type()))
+ local rescuda = cbn:forward(cast(input))
local groundtruth = gbn:forward(input)
local error = rescuda:float() - groundtruth:float()