github.com/torch/cunn.git
author    Gregory Chanan <gchanan@fb.com>  2016-11-09 00:06:52 +0300
committer Gregory Chanan <gchanan@fb.com>  2016-11-09 00:06:52 +0300
commit    aa18682ea8471c9b8b1b949621149cec8d816b1d (patch)
tree      f9ed0b9c210ac3ac39c63b276be928b5ff020787 /test.lua
parent    ba27398bc5fd98ad0df4b995b725beb8cbf355d1 (diff)
Rebase BatchNormalization.
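
This rebase drops the Timer-based CPU/GPU benchmarking from the BatchNormalization tests and instead loops over every entry in typenames, comparing each GPU tensor type against a CPU reference module built from the matching t2cpu host type. Below is a minimal sketch of the pattern, assuming the typenames list, the t2cpu map, mytester, precision_forward, and precision_forward_type defined elsewhere in test.lua; the input shape and feature count are illustrative only:

    -- Sketch of the per-type test pattern adopted by this commit.
    for k, typename in ipairs(typenames) do
       -- Build the CPU reference module in the matching host type.
       local ctype = t2cpu[typename]
       local input = torch.randn(8, 16):type(ctype)   -- illustrative shape
       local cpuMod = nn.BatchNormalization(16):type(ctype)
       local groundtruth = cpuMod:forward(input)

       -- Run the same module under the GPU type, sharing parameters.
       input = input:type(typename)
       local gpuMod = nn.BatchNormalization(16):type(typename)
       gpuMod.weight = cpuMod.weight:type(typename)
       gpuMod.bias = cpuMod.bias:type(typename)
       local rescuda = gpuMod:forward(input)

       -- Compare in double precision with a type-dependent tolerance.
       local err = (rescuda:double() - groundtruth:double()):abs():max()
       mytester:assertlt(err, precision_forward_type(precision_forward, typename),
          string.format('error on state (forward) with %s', typename))
    end

Note the relaxed tolerances in the rebased assertions: the evaluate-mode forward check scales the per-type precision by 3, and the backward weight/bias checks by 5, presumably to absorb accumulated rounding error in lower-precision types.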
Diffstat (limited to 'test.lua')
-rw-r--r--  test.lua  190
1 file changed, 84 insertions(+), 106 deletions(-)
diff --git a/test.lua b/test.lua
index 491c34f..93b5344 100644
--- a/test.lua
+++ b/test.lua
@@ -820,135 +820,113 @@ end
local function BatchNormalization_forward(moduleName, inputSize)
local planes = inputSize[2]
- local tm = {}
- local title = moduleName .. '.forward ' .. table.concat(inputSize, 'x')
- times[title] = tm
- local input = torch.randn(table.unpack(inputSize))
- local sbnorm = nn[moduleName](planes)
- local groundtruth = sbnorm:forward(input)
- local a = torch.Timer()
- for i = 1,nloop do
- groundtruth = sbnorm:forward(input)
- end
- tm.cpu = a:time().real
+ for k, typename in ipairs(typenames) do
+ local input = torch.randn(table.unpack(inputSize)):type(typename)
- input = input:cuda()
- local gbnorm = nn[moduleName](planes):cuda()
- gbnorm.weight = sbnorm.weight:cuda()
- gbnorm.bias = sbnorm.bias:cuda()
- local rescuda = gbnorm:forward(input)
+ local ctype = t2cpu[typename]
+ input = input:type(ctype)
+ local sbnorm = nn[moduleName](planes):type(ctype)
+ local groundtruth = sbnorm:forward(input)
- a:reset()
- for i = 1,nloop do
- rescuda = gbnorm:forward(input)
- end
- cutorch.synchronize()
- tm.gpu = a:time().real
+ input = input:type(typename)
+ local gbnorm = nn[moduleName](planes):type(typename)
+ gbnorm.weight = sbnorm.weight:type(typename)
+ gbnorm.bias = sbnorm.bias:type(typename)
+ local rescuda = gbnorm:forward(input)
- local error = rescuda:float() - groundtruth
- mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward)')
- mytester:assertlt((gbnorm.running_mean:float() - sbnorm.running_mean):abs():max(),
- precision_forward, 'error on running_mean (forward)')
- mytester:assertlt((gbnorm.running_var:float() - sbnorm.running_var):abs():max(),
- precision_forward, 'error on running_var (forward)')
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(), precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward) with %s', typename))
+ mytester:assertlt((gbnorm.running_mean:double() - sbnorm.running_mean:double()):abs():max(),
+ precision_forward_type(precision_forward, typename),
+ string.format('error on running_mean (forward) with %s', typename))
+ mytester:assertlt((gbnorm.running_var:double() - sbnorm.running_var:double()):abs():max(),
+ precision_forward_type(precision_forward, typename),
+ string.format('error on running_var (forward) with %s', typename))
+ end
end
local function BatchNormalization_forward_inference(moduleName, inputSize)
local planes = inputSize[2]
- local tm = {}
- local title = moduleName .. '.forward (evaluate) ' .. table.concat(inputSize, 'x')
- times[title] = tm
- local input = torch.randn(table.unpack(inputSize))
- local sbnorm = nn[moduleName](planes)
- sbnorm.running_mean:normal(1, 2)
- sbnorm.running_var:uniform(1e-3, 2)
- sbnorm:evaluate()
- local groundtruth = sbnorm:forward(input)
- local a = torch.Timer()
- for i = 1,nloop do
- groundtruth = sbnorm:forward(input)
- end
- tm.cpu = a:time().real
+ for k, typename in ipairs(typenames) do
+ local input = torch.randn(table.unpack(inputSize)):type(typename)
- input = input:cuda()
- local gbnorm = nn[moduleName](planes):cuda()
- gbnorm:evaluate()
- gbnorm.weight = sbnorm.weight:cuda()
- gbnorm.bias = sbnorm.bias:cuda()
- gbnorm.running_mean = sbnorm.running_mean:cuda()
- gbnorm.running_var = sbnorm.running_var:cuda()
- local rescuda = gbnorm:forward(input)
- a:reset()
- for i = 1,nloop do
- rescuda = gbnorm:forward(input)
- end
- cutorch.synchronize()
- tm.gpu = a:time().real
+ local ctype = t2cpu[typename]
+ input = input:type(ctype)
+ local sbnorm = nn[moduleName](planes):type(ctype)
+ sbnorm.running_mean:normal(1, 2)
+ sbnorm.running_var:uniform(1e-3, 2)
+ sbnorm:evaluate()
+ local groundtruth = sbnorm:forward(input)
- local error = rescuda:float() - groundtruth
- mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward evaluate)')
+ input = input:type(typename)
+ local gbnorm = nn[moduleName](planes):type(typename)
+ gbnorm:evaluate()
+ gbnorm.weight = sbnorm.weight:type(typename)
+ gbnorm.bias = sbnorm.bias:type(typename)
+ gbnorm.running_mean = sbnorm.running_mean:type(typename)
+ gbnorm.running_var = sbnorm.running_var:type(typename)
+ local rescuda = gbnorm:forward(input)
+
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(), 3*precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward evaluate) with %s', typename))
+ end
end
local function BatchNormalization_backward(moduleName, mode, inputSize, backwardFn)
assert(mode == 'training' or mode == 'evaluation', 'invalid mode')
local planes = inputSize[2]
- local tm = {}
- local title = moduleName .. '.backward ' .. table.concat(inputSize, 'x')
- times[title] = tm
- local input = torch.randn(table.unpack(inputSize))
- local gradOutput = torch.randn(table.unpack(inputSize))
- local sbnorm = nn[moduleName](planes)
- if mode == 'training' then
- sbnorm:training()
- else
- sbnorm:evaluate()
- end
- sbnorm:forward(input)
- sbnorm:zeroGradParameters()
- local groundgrad = backwardFn(sbnorm, input, gradOutput)
- local a = torch.Timer()
- for i = 1,nloop do
+ for k, typename in ipairs(typenames) do
+ local input = torch.randn(table.unpack(inputSize)):type(typename)
+ local gradOutput = torch.randn(table.unpack(inputSize)):type(typename)
+
+ local ctype = t2cpu[typename]
+ input = input:type(ctype)
+ gradOutput = gradOutput:type(ctype)
+ local sbnorm = nn[moduleName](planes):type(ctype)
+ if mode == 'training' then
+ sbnorm:training()
+ else
+ sbnorm:evaluate()
+ end
+ sbnorm:forward(input)
sbnorm:zeroGradParameters()
- groundgrad = backwardFn(sbnorm, input, gradOutput)
- end
- local groundweight = sbnorm.gradWeight
- local groundbias = sbnorm.gradBias
- tm.cpu = a:time().real
+ local groundgrad = backwardFn(sbnorm, input, gradOutput)
+ local groundweight = sbnorm.gradWeight
+ local groundbias = sbnorm.gradBias
- input = input:cuda()
- gradOutput = gradOutput:cuda()
- local gbnorm = nn[moduleName](planes):cuda()
- if mode == 'training' then
- gbnorm:training()
- else
- gbnorm:evaluate()
- end
- gbnorm.weight = sbnorm.weight:cuda()
- gbnorm.bias = sbnorm.bias:cuda()
- gbnorm:forward(input)
- gbnorm:zeroGradParameters()
- local rescuda = backwardFn(gbnorm, input, gradOutput)
- a:reset()
- for i = 1,nloop do
+ input = input:type(typename)
+ gradOutput = gradOutput:type(typename)
+ local gbnorm = nn[moduleName](planes):type(typename)
+ if mode == 'training' then
+ gbnorm:training()
+ else
+ gbnorm:evaluate()
+ end
+ gbnorm.weight = sbnorm.weight:type(typename)
+ gbnorm.bias = sbnorm.bias:type(typename)
+ gbnorm:forward(input)
gbnorm:zeroGradParameters()
- rescuda = backwardFn(gbnorm, input, gradOutput)
- end
- local weightcuda = gbnorm.gradWeight
- local biascuda = gbnorm.gradBias
- cutorch.synchronize()
- tm.gpu = a:time().real
+ local rescuda = backwardFn(gbnorm, input, gradOutput)
+ local weightcuda = gbnorm.gradWeight
+ local biascuda = gbnorm.gradBias
- local error = rescuda:float() - groundgrad
- local werror = weightcuda:float() - groundweight
- local berror = biascuda:float() - groundbias
+ local error = rescuda:double() - groundgrad:double()
+ local werror = weightcuda:double() - groundweight:double()
+ local berror = biascuda:double() - groundbias:double()
- mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ')
- mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ')
- mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ')
+ mytester:assertlt(error:abs():max(), 3*precision_backward_type(precision_backward, typename),
+ string.format('error on state (backward) with %s', typename))
+ mytester:assertlt(werror:abs():max(), 5*precision_backward_type(precision_backward, typename),
+ string.format('error on weight (backward) with %s', typename))
+ mytester:assertlt(berror:abs():max(), 5*precision_backward_type(precision_backward, typename),
+ string.format('error on bias (backward) with %s', typename))
+ end
end
local function testBatchNormalization(name, dim, k)