Welcome to mirror list, hosted at ThFree Co, Russian Federation.

bntest.lua « test - github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
blob: 8ebd1fa34e6930ad598392c383cde880c3d350cf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
require 'cunn'
require 'cudnn'

-- Compare cudnn.SpatialBatchNormalization against the reference
-- nn.SpatialBatchNormalization (running on CUDA via cunn) on a random input.
-- Prints the max absolute difference of the forward outputs ("error") and of
-- the gradients w.r.t. the input returned by backward ("error back").

local h = 5        -- spatial height of the input
local w = 5        -- spatial width of the input
local bsz = 4      -- batch size
local from = 4     -- number of feature planes (channels)
local eps = 1e-3   -- epsilon passed to both modules

local input = torch.randn(bsz, from, h, w):cuda()
local gradOutput = torch.randn(bsz, from, h, w):cuda()

-- The constructors take the number of feature planes, i.e. the channel
-- dimension of the input (`from`), not the batch size.
local cbn = cudnn.SpatialBatchNormalization(from, eps):cuda()
local gbn = nn.SpatialBatchNormalization(from, eps):cuda()

-- Each module initialises its own affine parameters. Copy the reference
-- values into the cudnn module so the comparison measures the two
-- implementations rather than their initialisation.
-- NOTE(review): guarded in case a module is built without weight/bias.
if cbn.weight and gbn.weight then cbn.weight:copy(gbn.weight) end
if cbn.bias and gbn.bias then cbn.bias:copy(gbn.bias) end

local groundtruth = gbn:forward(input)
local rescuda = cbn:forward(input)
local resgrad = cbn:backward(input, gradOutput)
local groundgrad = gbn:backward(input, gradOutput)

-- Named `err` rather than `error` so Lua's builtin error() is not shadowed.
local err = (rescuda:float() - groundtruth:float()):abs():max()
print("error", err)
err = (resgrad:float() - groundgrad:float()):abs():max()
print("error back", err)